package org.cspoker.ai.opponentmodels.weka;

import java.io.*;
import java.util.ArrayList;

import org.cspoker.ai.opponentmodels.weka.instances.InstancesBuilder;

import weka.classifiers.Classifier;
import weka.classifiers.trees.M5P;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Couples one ARFF file on disk with an in-memory window of the instances
 * written to it, and trains an {@link M5P} model from that data.
 * <p>
 * Every instance passed to {@link #write(Instance)} is appended to the file
 * and to the in-memory window. Once a model has been built, the window is
 * trimmed after each write based on the accuracy of the recorded
 * {@link Prediction}s and on the ratio of window size to the number of rules
 * in the model (see {@code adjustWindow()}).
 * <p>
 * This class performs no synchronization of its own.
 */
public class ARFFFile {

	/** Line separator appended after every written instance. */
	private final String nl = InstancesBuilder.nl;
	/** Directory prefix; the file location is {@code path + player + name}. */
	private final String path;
	/** Player this file belongs to; only its string form is used (in file names). */
	private final Object player;
	/** File name, appended to {@code path + player}. */
	private final String name;

	/** Writer on the ARFF file; flushed after every write. */
	private Writer file;
	/** Number of instances successfully written so far. */
	private long count = 0;
	private WekaOptions config;
	
	/** In-memory window of instances; trimmed by {@link #adjustWindow()}. */
	private Instances instances;
	/** Recorded predictions; {@code null} entries pad positions without a prediction. */
	private ArrayList<Prediction> predictions;
	/** Most recently trained model, or {@code null} until {@link #createModel} has run. */
	private M5P cl = null;
	
	/** Debug switch: when true, window statistics are printed to stdout. */
	private boolean echo = false;
	
	/**
	 * Creates (or truncates) the file {@code path + player + name}, writes the
	 * given header to it and loads the empty data set structure from it.
	 *
	 * @param path       directory prefix of the file
	 * @param player     player the file belongs to; its string form is part of the file name
	 * @param name       file name
	 * @param attributes header text written at the top of the file
	 * @param config     options controlling window adjustment and model creation
	 * @throws Exception if the file cannot be written or its header cannot be
	 *                   parsed; in that case the file is closed again before
	 *                   the exception propagates
	 */
	public ARFFFile(String path, Object player, String name, String attributes,
			WekaOptions config) throws Exception {
//		if (name.equals("PreFoldCallRaise.arff")) echo = true;
		this.path = path;
		this.player = player;
		this.name = name;
		this.config = config;

		// TODO: false => !config.arffOverwrite()
		file = new BufferedWriter(new FileWriter(path + player + name, false));
		try {
			file.write(attributes);
			file.flush();
			
			DataSource source = new DataSource(path + player + name);
			instances = source.getDataSet();
			// make it clean
			instances.delete();
		} catch (Exception e) {
			// Do not leak the open writer when initialisation fails.
			try {
				file.close();
			} catch (IOException ignored) {
				// the original failure is the one worth reporting
			}
			throw e;
		}
		
		predictions = new ArrayList<Prediction>();
		
		// mark every slot of the accuracy history as "not recorded yet"
		for (int i = 0; i < MAX_DECREASE; i++) {
			accuracies[i] = -1;
		}
	}

//	private double countDataLines() {
//		InputStream is;
//		try {
//			is = new BufferedInputStream(new FileInputStream(path + player + name));
//			byte[] c = new byte[1024];
//			int count = 0;
//			int readChars = 0;
//			boolean startReading = false;
//			while ((readChars = is.read(c)) != -1) {
//				for (int i = 0; i < readChars; ++i) {
//					if (c[i] == '\n' && startReading)
//						++count;
//					else if (!startReading && i >= 4 && c[i - 4] == '@'
//							&& c[i - 3] == 'd' && c[i - 2] == 'a'
//							&& c[i - 1] == 't' && c[i] == 'a')
//						startReading = true;
//				}
//			}
//			is.close();
//			return count + (count > 0 ? -1 : 0);
//		} catch (FileNotFoundException e) {
//			e.printStackTrace();
//		} catch (IOException e) {
//			e.printStackTrace();
//		}
//		return 0;
//	}
//
//	private boolean fileExists() throws FileNotFoundException {
//		return new File(path + player + name).exists();
//	}

	/**
	 * Closes the underlying file.
	 *
	 * @throws IOException if closing the writer fails
	 */
	public void close() throws IOException {
		file.close();
	}

	/**
	 * Appends the instance to the file and to the in-memory window, then
	 * re-evaluates the window size.
	 *
	 * @param instance the instance to record
	 * @throws IllegalStateException if writing to the file fails
	 */
	public void write(Instance instance) {
//		System.out.println("Writing instance " + (count +1) + " in file " + name);
		try {
			file.write(instance.toString() + nl);
			file.flush();
			// only count instances that actually reached the file
			count++;
			instances.add(instance);
			adjustWindow();
		} catch (IOException e) {
			throw new IllegalStateException(e);
		} 
	}
	
	/**
	 * Records a prediction at the position of the most recent instance in the
	 * window, padding any preceding positions that have no prediction with
	 * {@code null}.
	 *
	 * @param p the prediction to record
	 */
	public void addPrediction(Prediction p) {
//		if (echo) System.out.println("Adding " + p);
		// Compute the padding once: evaluating predictions.size() in the loop
		// condition while the list grows would stop the padding half-way.
		int missing = instances.numInstances() - predictions.size() - 1;
		for (int i = 0; i < missing; i++)
			predictions.add(null);
		predictions.add(p);
	}
	
	/**
	 * @return the number of instances currently in the in-memory window
	 */
	public double getWindowSize() {
		return instances.numInstances();
	}
	
	/**
	 * Computes the accuracy over all recorded (non-null) predictions as
	 * {@code (TP + TN) / (TP + TN + FP + FN)}.
	 *
	 * @return the accuracy, or {@code 0.0} when no prediction outcome has been
	 *         recorded
	 */
	public double getAccuracy() {
		double truePositive = 0.0;
		double trueNegative = 0.0;
		double falsePositive = 0.0;
		double falseNegative = 0.0;
		for (Prediction p : predictions) {
			if (p != null) {
				truePositive += p.getTruePositive();
				trueNegative += p.getTrueNegative();
				falsePositive += p.getFalsePositive();
				falseNegative += p.getFalseNegative();
			}
		}
		double total = trueNegative + truePositive + falseNegative + falsePositive;
		// Only null placeholders (or all-zero counts): avoid returning 0/0 = NaN.
		if (total == 0.0) return 0.0;
		return (trueNegative + truePositive) / total;
	}
	
	/** Number of most recent accuracy values kept for the trend estimate. */
	private final int MAX_DECREASE = 20;
	/** Accuracy history, oldest first; {@code -1} marks a slot not yet filled. */
	private double[] accuracies = new double[MAX_DECREASE];
	/** Number of filled slots in {@link #accuracies}; never exceeds {@code MAX_DECREASE}. */
	private int currentDecrease = 0;
	
	/**
	 * Appends the accuracy to the history (dropping the oldest value once the
	 * history is full) and reports whether the fitted trend is downward.
	 *
	 * @param accuracy the latest accuracy value
	 * @return {@code true} if the least-squares slope over the history is negative
	 */
	private boolean decreasingAcc(double accuracy) {
		if (currentDecrease < MAX_DECREASE) {
			// history not full yet: fill the next free slot
			accuracies[currentDecrease] = accuracy;
			// the counter is capped at MAX_DECREASE so it can never overflow
			currentDecrease++;
		} else {
			// history full: shift left by one and store the new value last
			System.arraycopy(accuracies, 1, accuracies, 0, MAX_DECREASE - 1);
			accuracies[MAX_DECREASE - 1] = accuracy;
		}
		
		double slope = calculateLeastSquaresSlope(accuracies);
		
		return (slope < 0);
	}
	
	/**
	 * Fits a straight line through the recorded values (array index as x,
	 * value as y) by ordinary least squares and returns its slope. Slots
	 * holding {@code -1} are treated as empty and ignored.
	 *
	 * @param accuracies the history to fit
	 * @return the slope, or {@code 0.0} when fewer than two values are recorded
	 */
	private double calculateLeastSquaresSlope(double[] accuracies) {
		// n must be the number of points that enter the sums, not the array
		// length; otherwise the slope is wrong while the history is filling up.
		int n = 0;
		double sumY = 0.0;
		double sumX = 0.0;
		double sumXY = 0.0;
		double sumX2 = 0.0;
		for (int i = 0; i < accuracies.length; i++) {
			if (echo) System.out.print(accuracies[i] + ", ");
			if (accuracies[i] != -1) {
				n++;
				sumY += accuracies[i];
				sumX += i;
				sumXY += i*accuracies[i];
				sumX2 += i * i;
			}
		}
		
		double denominator = (n * sumX2) - (sumX * sumX);
		if (n < 2 || denominator == 0.0) {
			// a trend is undefined for fewer than two points
			if (echo) System.out.println("slope: undefined (" + n + " points)");
			return 0.0;
		}
		
		double slope = ((n * sumXY) - (sumX * sumY)) / denominator;
		double intercept = (sumY - (sumX * slope)) / n;
		if (echo) System.out.print("slope: " + slope + ", intercept: " + intercept);
		if (echo) System.out.println("");
		
		return slope;
	}

	/** Whether the debug table header has been printed already. */
	private boolean printed = false;
	
	/**
	 * Drops the oldest instances (and their predictions) from the window.
	 * Does nothing until a model has been built.
	 * <p>
	 * With coverage defined as window size divided by the number of rules in
	 * the current model, the number of instances dropped is:
	 * <ul>
	 * <li>20% of the window when coverage is below the configured low
	 *     coverage, or accuracy is below the configured accuracy and
	 *     trending downward;</li>
	 * <li>2 when coverage exceeds twice the configured high coverage and
	 *     accuracy exceeds the configured accuracy;</li>
	 * <li>1 when coverage exceeds the configured high coverage and accuracy
	 *     exceeds the configured accuracy;</li>
	 * <li>0 otherwise.</li>
	 * </ul>
	 */
	private void adjustWindow() {
		if (cl == null) return;
		double windowSize = instances.numInstances();
		double coverage = windowSize / cl.measureNumRules();
		double accuracy = getAccuracy();
		boolean decreasing = decreasingAcc(accuracy);
		double l;
		if ((coverage < config.getCdLowCoverage()) ||
				(accuracy < config.getCdAccuracy() && decreasing))
			l = Math.round(0.2 *  windowSize);
		else if (coverage > 2 * config.getCdHighCoverage() &&
				accuracy > config.getCdAccuracy())
			l = 2;
		else if (coverage > config.getCdHighCoverage() &&
				accuracy > config.getCdAccuracy())
			l = 1;
		else
			l = 0;
		
		if (echo && !printed) {
			System.out.println("L \t Accuracy \t Coverage \t Instances \t Decreasing");
			printed = true;
		}
		
		if (echo) 
			System.out.println(l + "\t" + accuracy + "\t" + coverage + "\t" + windowSize + "\t" + decreasing);
		
		// never try to delete more instances than the window holds
		int toDrop = (int) Math.min(l, windowSize);
		for (int i = 0; i < toDrop; i++) {
			instances.delete(0);
			if (!predictions.isEmpty())
				predictions.remove(0);
		}
		
//		windowSize = windowSize - l;
//		System.out.println(name + ", " + windowSize + ", l: " + l + ", acc: " + accuracy + ", coverage: " + coverage);
	}
	
	/**
	 * @return {@code true} once more instances have been written than the
	 *         configured minimal number of learn examples
	 */
	public boolean isModelReady() {
		return count > config.getMinimalLearnExamples();
	}
	
	/**
	 * @return the number of instances written so far
	 */
	public long getNrExamples() {
		return count;
	}
	
	/**
	 * @return the file name this object was created with
	 */
	public String getName() {
		return name;
	}
	
	/**
	 * Trains an M5P regression tree and remembers it as the current model.
	 * <p>
	 * The training data is the in-memory window when concept-drift handling is
	 * enabled in the configuration, otherwise the complete file re-read from
	 * disk. The listed attributes are removed before training. When model
	 * persistency is enabled the model is serialized to
	 * {@code path + "../" + player + fileName + ".model"}.
	 *
	 * @param fileName     base name of the serialized model file
	 * @param attribute    name of the class attribute, used when the data does
	 *                     not define one itself
	 * @param rmAttributes names of the attributes to remove before training
	 * @return the trained classifier
	 * @throws IllegalArgumentException if an entry of {@code rmAttributes}
	 *                                  names no attribute of the data
	 * @throws Exception if loading, filtering, training or serialization fails
	 */
	public Classifier createModel(String fileName, String attribute, String[] rmAttributes) throws Exception {
//		System.out.println("Creating model for " + player + name);
		Instances data;
		if (config.solveConceptDrift())
			data = instances;
		else {
			DataSource source = new DataSource(path + player + name);
			data = source.getDataSet();
		}
		if (rmAttributes.length > 0) {
			// comma-separated, 1-based attribute indices for the -R option
			StringBuilder indices = new StringBuilder();
			for (int i = 0; i < rmAttributes.length; i++) {
				if (data.attribute(rmAttributes[i]) == null)
					throw new IllegalArgumentException("Unknown attribute '"
							+ rmAttributes[i] + "' in " + player + name);
				if (i > 0)
					indices.append(',');
				indices.append(1 + data.attribute(rmAttributes[i]).index());
			}
			String[] optionsDel = new String[] { "-R", indices.toString() };
			Remove remove = new Remove();
			remove.setOptions(optionsDel);
			remove.setInputFormat(data);
			data = Filter.useFilter(data, remove);
		}
		// setting class attribute if the data format does not provide this information
		// E.g., the XRFF format saves the class attribute information as well
		if (data.classIndex() == -1)
			data.setClass(data.attribute(attribute));
		
		// train M5P
		cl = new M5P();
		cl.setBuildRegressionTree(true);
		cl.setUnpruned(false);
		cl.setUseUnsmoothed(false);
		// further options...
		cl.buildClassifier(data);
		
//	    System.out.println("Number of instances: " + data.numInstances());
//	    System.out.println("Number of measures: " + cl.measureNumRules());
//	    System.out.println(cl);
		
		// save model + header
		if (config.modelPersistency())
			SerializationHelper.write(path + "../" + player + fileName + ".model", cl);
		
		return cl;
	}
}
